{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mistral.rs Python API Cookbook\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, install Rust: https://rustup.rs/\n",
    "%pip install mistralrs-cuda -v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralrs import Runner, Which, ChatCompletionRequest\n",
    "\n",
    "runner = Runner(\n",
    "    which=Which.GGUF(\n",
    "        tok_model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "        quantized_model_id=\"TheBloke/Mistral-7B-Instruct-v0.1-GGUF\",\n",
    "        quantized_filename=\"mistral-7b-instruct-v0.1.Q4_K_M.gguf\",\n",
    "    )\n",
    ")\n",
    "\n",
    "res = runner.send_chat_completion_request(\n",
    "    ChatCompletionRequest(\n",
    "        model=\"mistral\",\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
    "        ],\n",
    "        max_tokens=256,\n",
    "        presence_penalty=1.0,\n",
    "        top_p=0.1,\n",
    "        temperature=0.1,\n",
    "    )\n",
    ")\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets walk through this code.\n",
    "```python\n",
    "from mistralrs import Runner, Which, ChatCompletionRequest\n",
    "```\n",
    "\n",
    "This imports the requires classes for our example. The `Runner` is a class which handles loading and running the model, which are enumerated by the `Which` class.\n",
    "\n",
    "```python\n",
    "runner = Runner(\n",
    "    which=Which.GGUF(\n",
    "        tok_model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "        quantized_model_id=\"TheBloke/Mistral-7B-Instruct-v0.1-GGUF\",\n",
    "        quantized_filename=\"mistral-7b-instruct-v0.1.Q4_K_M.gguf\",\n",
    "        tokenizer_json=None,\n",
    "        repeat_last_n=64,\n",
    "    )\n",
    ")\n",
    "```\n",
    "\n",
    "This tells the `Runner` to actually load the model. It will use a CUDA, Metal, or CPU device depending on what `features` you set during compilation: [here](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#supported-accelerators).\n",
    "\n",
    "```python\n",
    "res = runner.send_chat_completion_request(\n",
    "    ChatCompletionRequest(\n",
    "        model=\"mistral\",\n",
    "        messages=[{\"role\":\"user\", \"content\":\"Tell me a story about the Rust type system.\"}],\n",
    "        max_tokens=256,\n",
    "        presence_penalty=1.0,\n",
    "        top_p=0.1,\n",
    "        temperature=0.1,\n",
    "    )\n",
    ")\n",
    "print(res)\n",
    "```\n",
    "\n",
    "Now we actually send a request! We can specify the messages just like with an OpenAI API."
   ]
  },
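  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the request takes OpenAI-style messages, a multi-turn conversation is simply a longer list of alternating `user` and `assistant` entries. The sketch below shows one way to assemble such a list in plain Python. Note that `build_messages` is a hypothetical helper written for this example, not part of `mistralrs`, and the conversation history is made up for illustration.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper (not part of mistralrs): build an OpenAI-style message list\n",
    "# from earlier (user, assistant) turn pairs plus the next user message.\n",
    "def build_messages(history, next_user_message):\n",
    "    messages = []\n",
    "    for user_text, assistant_text in history:\n",
    "        messages.append({\"role\": \"user\", \"content\": user_text})\n",
    "        messages.append({\"role\": \"assistant\", \"content\": assistant_text})\n",
    "    # The conversation always ends with the message we want answered.\n",
    "    messages.append({\"role\": \"user\", \"content\": next_user_message})\n",
    "    return messages\n",
    "\n",
    "\n",
    "messages = build_messages(\n",
    "    history=[(\"What is Rust?\", \"Rust is a systems programming language.\")],\n",
    "    next_user_message=\"Tell me a story about the Rust type system.\",\n",
    ")\n",
    "print(messages)\n",
    "```\n",
    "\n",
    "The resulting list could then be passed as `messages=messages` to `ChatCompletionRequest`, in place of the single-message list used above."
   ]
  },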
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralrs import Runner, Which, ChatCompletionRequest, Architecture\n",
    "\n",
    "runner = Runner(\n",
    "    which=Which.Plain(\n",
    "        model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "        arch=Architecture.Mistral,\n",
    "    )\n",
    ")\n",
    "\n",
    "res = runner.send_chat_completion_request(\n",
    "    ChatCompletionRequest(\n",
    "        model=\"mistral\",\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
    "        ],\n",
    "        max_tokens=256,\n",
    "        presence_penalty=1.0,\n",
    "        top_p=0.1,\n",
    "        temperature=0.1,\n",
    "    )\n",
    ")\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets walk through this code too and see the difference between loading a Plain model and loading a GGUF model.\n",
    "```python\n",
    "from mistralrs import Runner, Which, ChatCompletionRequest, Architecture\n",
    "```\n",
    "\n",
    "This imports the requires classes for our example. The `Runner` is a class which handles loading and running the model, which are enumerated by the `Which` class. Note that we also import the `Architecture` enum which controls the non-GGUF model's architecture.\n",
    "\n",
    "```python\n",
    "runner = Runner(\n",
    "    which=Which.Plain(\n",
    "        model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "        arch=Architecture.Mistral,\n",
    "        tokenizer_json=None,\n",
    "        repeat_last_n=64,\n",
    "    )\n",
    ")\n",
    "```\n",
    "\n",
    "This tells the `Runner` to actually load the model as a Mistral architecture. It will use a CUDA, Metal, or CPU device depending on what `features` you set during compilation: [here](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#supported-accelerators).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading a Mistral + GGUF model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralrs import Runner, Which\n",
    "\n",
    "runner = Runner(\n",
    "    which=Which.GGUF(\n",
    "        tok_model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "        quantized_model_id=\"TheBloke/Mistral-7B-Instruct-v0.1-GGUF\",\n",
    "        quantized_filename=\"mistral-7b-instruct-v0.1.Q4_K_M.gguf\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading a plain Mistral model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralrs import Runner, Which, Architecture\n",
    "\n",
    "runner = Runner(\n",
    "    which=Which.Plain(\n",
    "        tok_model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "        arch=Architecture.Mistral,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading an X-LoRA Zephyr model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralrs import Runner, Which\n",
    "\n",
    "runner = Runner(\n",
    "    which=Which.XLoraGGUF(\n",
    "        tok_model_id=None,  # Automatically determine from ordering file\n",
    "        quantized_model_id=\"TheBloke/zephyr-7B-beta-GGUF\",\n",
    "        quantized_filename=\"zephyr-7b-beta.Q4_0.gguf\",\n",
    "        xlora_model_id=\"lamm-mit/x-lora\",\n",
    "        order=\"orderings/xlora-paper-ordering.json\",\n",
    "        tgt_non_granular_index=None,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Runner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralrs import ChatCompletionRequest\n",
    "\n",
    "res = runner.send_chat_completion_request(\n",
    "    ChatCompletionRequest(\n",
    "        model=\"mistral\",\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
    "        ],\n",
    "        max_tokens=256,\n",
    "        presence_penalty=1.0,\n",
    "        top_p=0.1,\n",
    "        temperature=0.1,\n",
    "    )\n",
    ")\n",
    "print(res.choices[0].message.content)\n",
    "print(res.usage)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
